import java.util.ArrayList;

/**
 * Barnes-Hut octree used to approximate the gravitational forces between {@link CelestialBody}s.
 *
 * <p>The tree covers an axis-aligned cube centred on the origin whose side length is
 * {@code dimension}. Bodies outside that cube, or with a NaN coordinate, are not inserted and
 * therefore exert no force on other bodies.
 */
public class OctTree {
    /** Root node; {@code null} until the first in-bounds body is added. */
    private OctTreeNode head;
    /** Side length of the root cube, which is centred on the origin. */
    private final double dimension;

    /** Creates a tree whose root cube has a side length of {@code 4 * Simulation.AU}. */
    public OctTree() {
        this(4 * Simulation.AU);
    }

    /**
     * Creates a tree whose root cube, centred on the origin, has the given side length.
     *
     * @param dimension side length of the root cube
     */
    public OctTree(double dimension) {
        this.dimension = dimension;
    }

    /**
     * Inserts every body of {@code bodies} into the tree.
     *
     * @param bodies bodies to insert; bodies outside the root cube are skipped
     */
    public void construct(ArrayList<CelestialBody> bodies) {
        for (CelestialBody body : bodies) {
            add(body);
        }
    }

    /**
     * Inserts a single body. Bodies lying outside the root cube are ignored, as are bodies with a
     * NaN coordinate (which cannot be assigned to any octant).
     *
     * @param body body to insert
     */
    public void add(CelestialBody body) {
        if (!isInBounds(body.getPosition())) {
            return;
        }
        if (head == null) {
            head = new OctTreeNode(dimension, new Vector3(0, 0, 0));
        }
        head.add(body);
    }

    // True when every coordinate lies within [-dimension/2, dimension/2]. Written as a positive
    // "<=" test so that NaN (which fails every comparison) is rejected as out of bounds.
    private boolean isInBounds(Vector3 p) {
        double half = dimension / 2;
        return Math.abs(p.getX()) <= half
                && Math.abs(p.getY()) <= half
                && Math.abs(p.getZ()) <= half;
    }

    /**
     * Computes the net gravitational force exerted by the tree on every body in {@code bodies},
     * then applies each force via {@code move}. All forces are computed before any body is moved,
     * so no force is influenced by a move made in the same step.
     *
     * @param bodies bodies to advance; they need not all be stored in the tree (an empty tree
     *     yields a zero force)
     */
    public void force(ArrayList<CelestialBody> bodies) {
        ArrayList<Vector3> forces = new ArrayList<>(bodies.size());
        for (CelestialBody body : bodies) {
            forces.add(head == null ? new Vector3(0, 0, 0) : head.force(body));
        }
        for (int i = 0; i < bodies.size(); i++) {
            bodies.get(i).move(forces.get(i));
        }
    }

    /** Draws every body stored in the tree; does nothing if the tree is empty. */
    public void draw() {
        if (head != null) {
            head.draw();
        }
    }

    /**
     * A cubic cell of the octree. A node is either empty, a leaf holding exactly one body, or an
     * internal node whose {@code sub} array holds up to eight children. Every non-empty node keeps
     * the total mass and mass-weighted centre of all bodies below it.
     */
    private static final class OctTreeNode {
        private CelestialBody body;     // the single body of a leaf; null for empty/internal nodes
        private double massTotal;
        private Vector3 massCenter;
        private final double dimension; // side length of this cell
        private OctTreeNode[] sub;      // null until the node is subdivided
        private final Vector3 position; // geometric centre of this cell (not the centre of mass)

        OctTreeNode(CelestialBody body, double dimension, Vector3 position) {
            this(dimension, position);
            this.body = body;
            this.massTotal = body.getMass();
            this.massCenter = body.getPosition();
        }

        OctTreeNode(double dimension, Vector3 position) {
            this.dimension = dimension;
            this.position = position;
        }

        /**
         * Returns the force exerted on {@code body} by the bodies in this subtree. A cell that is
         * far enough away relative to its size is approximated by a single point mass at its
         * centre of mass.
         */
        Vector3 force(CelestialBody body) {
            if (sub == null) {
                // Leaf or empty cell: no force from nothing, and no force of a body on itself.
                if (this.body == null || this.body == body) {
                    return new Vector3(0, 0, 0);
                }
                return body.gravitationalForce(this.body);
            }

            // Opening criterion: distance / size > T (equivalently size / distance < 1 / T).
            boolean isFarAway = body.getPosition().distanceTo(massCenter) / dimension > Simulation.T;
            if (isFarAway) {
                return body.gravitationalForce(massCenter, massTotal);
            }

            // Too close to approximate: sum the contributions of the children.
            Vector3 total = new Vector3(0, 0, 0);
            for (OctTreeNode child : sub) {
                if (child != null) {
                    total = total.plus(child.force(body));
                }
            }
            return total;
        }

        /** Inserts {@code body} into this subtree, subdividing a leaf when needed. */
        void add(CelestialBody body) {
            // Empty node: it simply becomes a leaf.
            if (this.body == null && sub == null) {
                this.body = body;
                this.massTotal = body.getMass();
                this.massCenter = body.getPosition();
                return;
            }

            // Leaf: subdivide and push the existing body down into its octant.
            // NOTE(review): two bodies at exactly the same position always fall into the same
            // octant, so this recursion never terminates (StackOverflowError) for such input.
            if (sub == null) {
                sub = new OctTreeNode[8];
                insertIntoChild(this.body);
                this.body = null;
            }
            insertIntoChild(body);

            // Incrementally update the mass-weighted centre and the total mass.
            massCenter = massCenter.times(massTotal).plus(body.getPosition().times(body.getMass()));
            massTotal += body.getMass();
            massCenter = massCenter.times(1 / massTotal);
        }

        // Places b in the child cell covering its position, creating that child if needed.
        private void insertIntoChild(CelestialBody b) {
            int i = octantOf(b.getPosition());
            if (sub[i] == null) {
                sub[i] = new OctTreeNode(b, dimension / 2, childCenter(i));
            } else {
                sub[i].add(b);
            }
        }

        // Octant index of p relative to this cell's centre:
        // bit 0 set when x > centre.x, bit 1 set when y <= centre.y, bit 2 set when z > centre.z.
        private int octantOf(Vector3 p) {
            double x = p.getX();
            double y = p.getY();
            double z = p.getZ();
            if (Double.isNaN(x) || Double.isNaN(y) || Double.isNaN(z)) {
                throw new IllegalArgumentException(
                        "Invalid position: (" + x + ", " + y + ", " + z + ")");
            }
            int index = 0;
            if (x > position.getX()) {
                index |= 1;
            }
            if (y <= position.getY()) {
                index |= 2;
            }
            if (z > position.getZ()) {
                index |= 4;
            }
            return index;
        }

        // Geometric centre of child cell i, using the same bit layout as octantOf.
        private Vector3 childCenter(int i) {
            double offset = dimension / 4;
            double dx = (i & 1) != 0 ? offset : -offset;
            double dy = (i & 2) != 0 ? -offset : offset;
            double dz = (i & 4) != 0 ? offset : -offset;
            return new Vector3(position.getX() + dx, position.getY() + dy, position.getZ() + dz);
        }

        /** Draws every body in this subtree; an empty cell draws nothing. */
        void draw() {
            if (sub == null) {
                if (body != null) {
                    body.draw();
                }
                return;
            }
            for (OctTreeNode child : sub) {
                if (child != null) {
                    child.draw();
                }
            }
        }
    }
}
